{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_"
      },
      "source": [
        "# Haiku and Flax interop 🥂\n",
        "\n",
        "Utilities to move seamlessly between Haiku and Flax.\n",
        "\n",
        "## Flax inside Haiku\n",
        "\n",
        "Using a Flax module inside a `hk.transform` (or `hk.transform_with_state`) is\n",
        "straightforward.\n",
        "\n",
        "First construct an instance of your module, and then use `hkflax.lift` to\n",
        "\"lift\" (see [`hk.lift`]) the parameters and state from the Flax module into the\n",
        "Haiku transform.\n",
        "\n",
        "Example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b2hj0UXjhbHO",
        "outputId": "7914b929-d9a6-4992-d5a6-0cb6d93cde12"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "Array([[ 0.33030465, -1.3496182 ,  0.02847686, -1.6579462 , -0.9166192 ,\n",
              "         0.2883583 , -0.046898  ,  0.6414894 , -0.404975  , -2.1162813 ]],      dtype=float32)"
            ]
          },
          "execution_count": 1,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import haiku as hk\n",
        "import haiku.experimental.flax as hkflax\n",
        "import flax.linen as flax_nn\n",
        "\n",
        "def f(x):\n",
        "  mod = hkflax.lift(flax_nn.Dense(10), name='my_flax_module')\n",
        "  x = mod(x)\n",
        "  return x\n",
        "\n",
        "f = hk.transform(f)\n",
        "x = jnp.ones([1, 1])\n",
        "rng = jax.random.PRNGKey(42)\n",
        "params = f.init(rng, x)   # params holds the parameters of the lifted Dense module.\n",
        "f.apply(params, None, x)  # The lifted Dense module reads its parameters from params."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mDMK25brhbHO"
      },
      "source": [
        "To use a stateful module, simply swap `hk.transform` for\n",
        "`hk.transform_with_state`.\n",
        "\n",
        "## Haiku inside Flax\n",
        "\n",
        "There are two supported approaches for converting Haiku code to Flax. Both\n",
        "produce a Flax linen `nn.Module` which encapsulates the Haiku code and provides\n",
        "`init` and `apply` methods to create and use parameters and state.\n",
        "\n",
        "-   [Convert an `hk.Module` to `nn.Module`](#hk-Module).\n",
        "-   [Convert an `hk.transform` or `hk.transform_with_state` to `nn.Module`](#hk-transform).\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6cPBW1nRhbHO"
      },
      "source": [
        "### Converting `hk.Module` {#hk-Module}\n",
        "\n",
        "For stateless modules you simply need to construct the Flax module via\n",
        "`hkflax.Module.create`:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uI14KWyShbHO"
      },
      "outputs": [],
      "source": [
        "mod = hkflax.Module.create(hk.Linear, 1)  # hk.Linear(1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KsWdu7yHhbHO"
      },
      "source": [
        "You can use this like a regular Flax `nn.Module` (because it is one!):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xIWWZ852hbHO"
      },
      "outputs": [],
      "source": [
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones([1, 1])\n",
        "variables = mod.init(rng, x)\n",
        "out = mod.apply(variables, x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zMPIT8f2hbHO"
      },
      "source": [
        "For a stateful module like ResNet you also need to handle the output state.\n",
        "Again, this works the same way as for stateful Flax modules:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B2gd3DNohbHO"
      },
      "outputs": [],
      "source": [
        "mod = hkflax.Module.create(hk.nets.ResNet50, 10)\n",
        "\n",
        "# Regular Flax code from here on:\n",
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones([1, 224, 224, 3])\n",
        "variables = mod.init(rng, x, is_training=True)\n",
        "for _ in range(10):\n",
        "  out, state_out = mod.apply(variables, x, is_training=True,\n",
        "                             mutable=['state'])\n",
        "  variables = {**variables, **state_out}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JOoy3RqLhbHO"
      },
      "source": [
        "### Converting `hk.transform` or `hk.transform_with_state` {#hk-transform}\n",
        "\n",
        "`hkflax.Module` can be created from the result of `hk.transform` or\n",
        "`hk.transform_with_state` if you prefer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6_OiX0AQhbHO"
      },
      "outputs": [],
      "source": [
        "def mlp(x):\n",
        "  x = hk.Linear(300)(x)\n",
        "  x = hk.Linear(100)(x)\n",
        "  x = hk.Linear(10)(x)\n",
        "  return x\n",
        "\n",
        "mlp = hk.transform(mlp)\n",
        "mlp = hkflax.Module(mlp)\n",
        "\n",
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones([1, 28 * 28])\n",
        "variables = mlp.init(rng, x)\n",
        "out = mlp.apply(variables, x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TiUMVHHyhbHP"
      },
      "source": [
        "### Gotchas\n",
        "\n",
        "#### Initialization is different\n",
        "\n",
        "Flax and Haiku take different approaches to RNG key splitting. As a result, at\n",
        "the moment the parameters returned from `hkflax.Module(f).init` will differ from\n",
        "those returned from `hk.transform(f).init`.\n",
        "\n",
        "We have a route to making the Haiku transform match the initialization of the\n",
        "Flax module, but there is no path for the opposite direction at the moment.\n",
        "\n",
        "If aligning initialization across Haiku and Flax is important to you, we\n",
        "recommend using one of the libraries to create the parameters and then\n",
        "manipulating the params/state dictionary to match the other library as needed:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uI03vP-6hbHP"
      },
      "source": [
        "```python\n",
        "# Utilities.\n",
        "import haiku.data_structures as hkds\n",
        "\n",
        "def make_flat(d):\n",
        "  return {f'{m}/{n}': w for m, n, w in hkds.traverse(d)}\n",
        "\n",
        "def make_nested(d):\n",
        "  out = {}\n",
        "  for k, w in d.items():\n",
        "    m, n = k.rsplit('/', 1)\n",
        "    out.setdefault(m, {})\n",
        "    out[m][n] = w\n",
        "  return out\n",
        "\n",
        "# The two modules here should be equivalent when run with Flax or Haiku.\n",
        "f = hk.transform_with_state(...)\n",
        "flax_mod = hkflax.Module(f)\n",
        "\n",
        "# Option 1: Convert Haiku initialized params/state to Flax.\n",
        "params, state = f.init(...)\n",
        "variables = {'params': make_flat(params), 'state': make_flat(state)}\n",
        "\n",
        "# Option 2: Convert Flax initialized variables to Haiku.\n",
        "variables = flax_mod.init(...)\n",
        "params = make_nested(variables.get('params', {}))\n",
        "state = make_nested(variables.get('state', {}))\n",
        "\n",
        "# The output of the Haiku transformed function or the Flax function should be\n",
        "# equivalent with either init.\n",
        "out, state = f.apply(params, state, ...)\n",
        "out, variables_out = flax_mod.apply(variables, ..., mutable=['state'])\n",
        "```"
      ]
    },
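    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check of the flatten/nest idea used by the utilities above,\n",
        "the sketch below round-trips a small hand-written dictionary. It is only an\n",
        "illustration under stated assumptions: `make_flat` here walks the nested dict\n",
        "directly instead of calling `hkds.traverse`, so it runs without Haiku, and the\n",
        "module and parameter names are made up for the example.\n",
        "\n",
        "```python\n",
        "# Stand-in helpers that only need plain Python dicts (no Haiku or Flax).\n",
        "def make_flat(d):\n",
        "  # {'module': {'name': value}} -> {'module/name': value}\n",
        "  return {f'{m}/{n}': w for m, bundle in d.items() for n, w in bundle.items()}\n",
        "\n",
        "def make_nested(d):\n",
        "  out = {}\n",
        "  for k, w in d.items():\n",
        "    # Split on the last '/' so module names such as 'mlp/linear' stay intact.\n",
        "    m, n = k.rsplit('/', 1)\n",
        "    out.setdefault(m, {})[n] = w\n",
        "  return out\n",
        "\n",
        "# Hypothetical Haiku-style params: module name -> parameter name -> value.\n",
        "params = {'mlp/linear': {'w': [1.0, 2.0], 'b': [0.0]},\n",
        "          'mlp/linear_1': {'w': [3.0], 'b': [0.5]}}\n",
        "\n",
        "flat = make_flat(params)\n",
        "assert sorted(flat) == ['mlp/linear/b', 'mlp/linear/w',\n",
        "                        'mlp/linear_1/b', 'mlp/linear_1/w']\n",
        "assert make_nested(flat) == params  # The round trip is lossless.\n",
        "```\n",
        "\n",
        "Splitting on the last `/` matters because Haiku module names can themselves\n",
        "contain `/`, while the parameter name is always the final path component."
      ]
    },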
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XrZ8te3shbHP"
      },
      "source": [
        "#### Multiple forward methods\n",
        "\n",
        "`hkflax.Module` only supports `__call__` at the moment. Please let us know if\n",
        "this is a blocker for you.\n",
        "\n",
        "[`hk.lift`]: https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.lift:"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
